#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Functions/castTypeToEither.h>
#include <Interpreters/castColumn.h>
#include <boost/math/distributions/normal.hpp>
#include <Common/typeid_cast.h>


namespace DB
{

/// Error codes thrown by this file. Only declared here (`extern`); the definitions live elsewhere.
namespace ErrorCodes
{
    /// Thrown when an argument has the wrong data type.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    /// Thrown when the pooling-mode argument is neither `pooled` nor `unpooled`.
    extern const int BAD_ARGUMENTS;
}


/// Two-sample Z-test for the difference of two population proportions.
///
/// proportionsZTest(successes_x, successes_y, trials_x, trials_y, conf_level, pool_type) returns a named tuple
/// (z_statistic, p_value, confidence_interval_low, confidence_interval_high), all Float64:
///  - the z statistic uses either the pooled or the unpooled standard error, selected by `pool_type`;
///  - the p-value is two-sided;
///  - the confidence interval for p_x - p_y always uses the unpooled standard error.
/// Rows for which the test is undefined (see executeImpl) produce NaN in all four fields.
class FunctionTwoSampleProportionsZTest : public IFunction
{
public:
    static constexpr auto POOLED = "pooled";
    static constexpr auto UNPOOLED = "unpooled";

    static constexpr auto name = "proportionsZTest";

    static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionTwoSampleProportionsZTest>(); }

    String getName() const override { return name; }

    size_t getNumberOfArguments() const override { return 6; }
    /// `pool_type` selects the formula for the whole block, so it must be a constant.
    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {5}; }

    bool useDefaultImplementationForNulls() const override { return false; }
    bool useDefaultImplementationForConstants() const override { return true; }
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }

    /// Tuple(z_statistic Float64, p_value Float64, confidence_interval_low Float64, confidence_interval_high Float64).
    static DataTypePtr getReturnType()
    {
        auto float_data_type = std::make_shared<DataTypeNumber<Float64>>();
        DataTypes types(4, float_data_type);

        Strings names{"z_statistic", "p_value", "confidence_interval_low", "confidence_interval_high"};

        return std::make_shared<DataTypeTuple>(std::move(types), std::move(names));
    }

    /// Validates argument types: four unsigned integer counts, a float confidence level and a constant string pooling mode.
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
    {
        /// successes_x, successes_y, trials_x, trials_y
        for (size_t i = 0; i < 4; ++i)
        {
            if (!isUInt(arguments[i].type))
            {
                throw Exception(
                    ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                    "Argument {} of function {} must be an unsigned integer, got {}",
                    i + 1,
                    getName(),
                    arguments[i].type->getName());
            }
        }

        if (!isFloat(arguments[4].type))
        {
            throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "The fifth argument {} of function {} should be a float",
                arguments[4].type->getName(),
                getName()};
        }

        /// A column is only present here for constants; constancy is checked again in executeImpl.
        if (!isString(arguments[5].type) || !arguments[5].column)
        {
            throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "The sixth argument {} of function {} should be a constant string",
                arguments[5].type->getName(),
                getName()};
        }

        return getReturnType();
    }


    ColumnPtr executeImpl(const ColumnsWithTypeAndName & const_arguments, const DataTypePtr &, size_t input_rows_count) const override
    {
        auto arguments = const_arguments;
        /// Only the last argument has to be constant; materialize the others so they can be read row by row.
        for (size_t i = 0; i < 5; ++i)
            arguments[i].column = arguments[i].column->convertToFullColumnIfConst();

        static const auto uint64_data_type = std::make_shared<DataTypeNumber<UInt64>>();

        auto column_successes_x = castColumnAccurate(arguments[0], uint64_data_type);
        const auto & data_successes_x = checkAndGetColumn<ColumnVector<UInt64>>(*column_successes_x).getData();

        auto column_successes_y = castColumnAccurate(arguments[1], uint64_data_type);
        const auto & data_successes_y = checkAndGetColumn<ColumnVector<UInt64>>(*column_successes_y).getData();

        auto column_trials_x = castColumnAccurate(arguments[2], uint64_data_type);
        const auto & data_trials_x = checkAndGetColumn<ColumnVector<UInt64>>(*column_trials_x).getData();

        auto column_trials_y = castColumnAccurate(arguments[3], uint64_data_type);
        const auto & data_trials_y = checkAndGetColumn<ColumnVector<UInt64>>(*column_trials_y).getData();

        static const auto float64_data_type = std::make_shared<DataTypeNumber<Float64>>();

        auto column_confidence_level = castColumnAccurate(arguments[4], float64_data_type);
        const auto & data_confidence_level = checkAndGetColumn<ColumnVector<Float64>>(*column_confidence_level).getData();

        const String usevar = checkAndGetColumnConst<ColumnString>(*arguments[5].column).getValue<String>();

        if (usevar != UNPOOLED && usevar != POOLED)
            throw Exception{ErrorCodes::BAD_ARGUMENTS,
                "The sixth argument '{}' of function {} must be equal to `pooled` or `unpooled`",
                usevar,
                getName()};

        const bool is_unpooled = (usevar == UNPOOLED);

        auto res_z_statistic = ColumnFloat64::create();
        auto & data_z_statistic = res_z_statistic->getData();
        data_z_statistic.reserve(input_rows_count);

        auto res_p_value = ColumnFloat64::create();
        auto & data_p_value = res_p_value->getData();
        data_p_value.reserve(input_rows_count);

        auto res_ci_lower = ColumnFloat64::create();
        auto & data_ci_lower = res_ci_lower->getData();
        data_ci_lower.reserve(input_rows_count);

        auto res_ci_upper = ColumnFloat64::create();
        auto & data_ci_upper = res_ci_upper->getData();
        data_ci_upper.reserve(input_rows_count);

        auto insert_values_into_result = [&data_z_statistic, &data_p_value, &data_ci_lower, &data_ci_upper](
                                             Float64 z_stat, Float64 p_value, Float64 lower, Float64 upper)
        {
            data_z_statistic.emplace_back(z_stat);
            data_p_value.emplace_back(p_value);
            data_ci_lower.emplace_back(lower);
            data_ci_upper.emplace_back(upper);
        };

        static constexpr Float64 nan = std::numeric_limits<Float64>::quiet_NaN();
        static constexpr Float64 inf = std::numeric_limits<Float64>::infinity();

        const boost::math::normal_distribution<> nd(0.0, 1.0);

        for (size_t row_num = 0; row_num < input_rows_count; ++row_num)
        {
            const UInt64 successes_x = data_successes_x[row_num];
            const UInt64 successes_y = data_successes_y[row_num];
            const UInt64 trials_x = data_trials_x[row_num];
            const UInt64 trials_y = data_trials_y[row_num];
            const Float64 confidence_level = data_confidence_level[row_num];

            /// The test is undefined for empty samples, impossible counts (more successes than trials)
            /// and confidence levels outside [0, 1]. Rows with zero successes in either sample are rejected as well.
            if (trials_x == 0 || trials_y == 0 || successes_x == 0 || successes_y == 0
                || successes_x > trials_x || successes_y > trials_y
                || !std::isfinite(confidence_level) || confidence_level < 0.0 || confidence_level > 1.0)
            {
                insert_values_into_result(nan, nan, nan, nan);
                continue;
            }

            const Float64 props_x = static_cast<Float64>(successes_x) / trials_x;
            const Float64 props_y = static_cast<Float64>(successes_y) / trials_y;
            const Float64 diff = props_x - props_y;

            /// Unpooled standard error; also used for the confidence interval in both modes.
            const Float64 se = std::sqrt(props_x * (1.0 - props_x) / trials_x + props_y * (1.0 - props_y) / trials_y);

            /// z-statistics
            /// z = \frac{ \bar{p_{1}} - \bar{p_{2}} }{ \sqrt{ \frac{ \bar{p_{1}} \left ( 1 - \bar{p_{1}} \right ) }{ n_{1} } + \frac{ \bar{p_{2}} \left ( 1 - \bar{p_{2}} \right ) }{ n_{2} } } }
            Float64 zstat;
            if (is_unpooled)
            {
                zstat = diff / se;
            }
            else
            {
                /// Sum in floating point so that huge counts cannot wrap around UInt64.
                const Float64 successes_total = static_cast<Float64>(successes_x) + static_cast<Float64>(successes_y);
                const Float64 trials_total = static_cast<Float64>(trials_x) + static_cast<Float64>(trials_y);
                const Float64 p_pooled = successes_total / trials_total;
                const Float64 trials_fact = 1.0 / trials_x + 1.0 / trials_y;
                zstat = diff / std::sqrt(p_pooled * (1.0 - p_pooled) * trials_fact);
            }

            /// Degenerate standard error (e.g. both proportions equal to 1) gives 0/0 or x/0.
            if (unlikely(!std::isfinite(zstat)))
            {
                insert_values_into_result(nan, nan, nan, nan);
                continue;
            }

            /// Two-sided p-value: 2 * P(Z > |z|).
            const Float64 pvalue = 2 * (1 - boost::math::cdf(nd, std::abs(zstat)));

            /// Two-sided critical value. A 100% confidence interval is unbounded; quantile(nd, 0) would
            /// raise a boost overflow error (a thrown exception under the default policy), so handle it explicitly.
            const Float64 z_critical = (confidence_level == 1.0) ? inf : -boost::math::quantile(nd, (1.0 - confidence_level) / 2.0);
            const Float64 dist = z_critical * se;

            insert_values_into_result(zstat, pvalue, diff - dist, diff + dist);
        }

        return ColumnTuple::create(
            Columns{std::move(res_z_statistic), std::move(res_p_value), std::move(res_ci_lower), std::move(res_ci_upper)});
    }
};


/// Registers proportionsZTest together with its user-facing documentation.
REGISTER_FUNCTION(ZTest)
{
    FunctionDocumentation::Description description = R"(
Returns test statistics for the two proportion Z-test - a statistical test for comparing the proportions from two populations x and y.
The function supports both pooled and unpooled estimation methods for the standard error.
In the pooled version, the two proportions are averaged and only one proportion is used to estimate the standard error.
In the unpooled version, the two proportions are used separately.
The confidence interval for the difference of the proportions is always computed with the unpooled standard error.
If the inputs are invalid (for example, zero successes in either sample, more successes than trials, or a confidence level outside [0, 1]), all four returned values are NaN.
    )";
    FunctionDocumentation::Syntax syntax = "proportionsZTest(successes_x, successes_y, trials_x, trials_y, conf_level, pool_type)";
    FunctionDocumentation::Arguments arguments = {
        {"successes_x", "Number of successes in population x.", {"UInt64"}},
        {"successes_y", "Number of successes in population y.", {"UInt64"}},
        {"trials_x", "Number of trials in population x.", {"UInt64"}},
        {"trials_y", "Number of trials in population y.", {"UInt64"}},
        {"conf_level", "Confidence level of the confidence interval for the difference of the proportions, in the range [0, 1].", {"Float64"}},
        {"pool_type", "Selection of pooling method for standard error estimation. Can be either 'unpooled' or 'pooled'.", {"String"}}
    };
    FunctionDocumentation::ReturnedValue returned_value = {"Returns a named tuple containing: `z_statistic` (Z statistic), `p_value` (two-sided P value), `confidence_interval_low` (lower bound of the confidence interval), `confidence_interval_high` (upper bound of the confidence interval).", {"Tuple(Float64, Float64, Float64, Float64)"}};
    FunctionDocumentation::Examples examples = {
    {
        "Usage example",
        R"(
SELECT proportionsZTest(10, 11, 100, 101, 0.95, 'unpooled');
        )",
        R"(
┌─proportionsZTest(10, 11, 100, 101, 0.95, 'unpooled')───────────────────────────────┐
│ (-0.20656724435948853,0.8363478437079654,-0.09345975390115283,0.07563797172293502) │
└────────────────────────────────────────────────────────────────────────────────────┘
        )"
    }
    };
    FunctionDocumentation::IntroducedIn introduced_in = {22, 3};
    FunctionDocumentation::Category category = FunctionDocumentation::Category::Mathematical;
    FunctionDocumentation documentation = {description, syntax, arguments, returned_value, examples, introduced_in, category};
    /// Pass the documentation along; otherwise the object built above is silently discarded.
    factory.registerFunction<FunctionTwoSampleProportionsZTest>(documentation);
}

}
